import os

import numpy as np
import pickle as pkl
import networkx as nx
import scipy.sparse as sp
from scipy.sparse.linalg import eigsh
import sys
import torch
import torch.nn as nn
from ragraph_utils.complex import Cochain, Complex
from torch_geometric.data import Batch, Data
def parse_skipgram(fname):
    """Read a whitespace-separated embedding file into a dense array.

    The first two tokens are the node count and the feature count. They are
    followed by one record per node: a 1-based node id and then
    ``nb_features`` float values.

    Args:
        fname: Path of the text file to read.

    Returns:
        A ``(nb_nodes, nb_features)`` float array whose row ``id - 1`` holds
        the values of the record with that id. Rows whose id never appears
        in the file keep whatever ``np.empty`` allocated.
    """
    with open(fname) as handle:
        tokens = handle.read().split()

    nb_nodes, nb_features = int(tokens[0]), int(tokens[1])
    embeddings = np.empty((nb_nodes, nb_features))

    # Each record occupies one id token plus nb_features value tokens.
    stride = nb_features + 1
    for record in range(nb_nodes):
        start = 2 + record * stride
        row = int(tokens[start]) - 1  # ids in the file are 1-based
        for col in range(nb_features):
            embeddings[row, col] = float(tokens[start + 1 + col])
    return embeddings

# Process (a subset of) a TU dataset into a standard form
import numpy as np
import scipy.sparse as sp
import torch
import networkx as nx
from ragraph_utils.complex import Cochain, Complex
from torch_geometric.data import Batch, Data

def process_tu(data, class_num, node_class, max_ring=6):
    """Merge one or more graphs into a single disjoint-union representation.

    Node ids of graph ``k`` are shifted by the total node count of graphs
    ``0..k-1`` so that all graphs share one global id space.

    Args:
        data: A ``Batch``, a list of ``Data`` objects, or a single ``Data``.
        class_num: Length of the one-hot graph label vector.
        node_class: Number of leading columns of ``x`` kept as node features.
        max_ring: Cycles with more than this many nodes (or fewer than 3)
            are not turned into 2-cells.

    Returns:
        A tuple ``(features, adjacency, labels, complex_obj, batch)``:

        * ``features``: ``[N_total, node_class]`` array of node features.
        * ``adjacency``: block-diagonal CSR matrix built from each graph's
          ``edge_index`` (duplicate entries are summed).
        * ``labels``: ``[num_graphs, class_num]`` one-hot array.
        * ``complex_obj``: ``Complex`` built from the vertex, edge and
          (possibly ``None``) 2-cell cochains.
          NOTE(review): the exact contract of ``Cochain``/``Complex`` lives
          in ``ragraph_utils.complex`` and is not visible here.
        * ``batch``: ``[N_total]`` long tensor mapping each node to the index
          of its graph.

    Raises:
        ValueError: If ``data`` is not one of the supported types.
    """
    if isinstance(data, Batch):
        data_list = data.to_data_list()
    elif isinstance(data, list):
        data_list = data
    elif isinstance(data, Data):
        data_list = [data]
    else:
        raise ValueError(f"Unsupported data type: {type(data)}")

    labels_list = []
    features_list = []
    adjacency_blocks = []
    batch_list = []
    all_edges = []
    all_ring_rows = []  # global edge index of every edge/ring incidence
    all_ring_cols = []  # id of the ring that incidence belongs to
    num_rings = 0
    offset = 0

    for graph_id, graph in enumerate(data_list):
        feature = graph.x[:, :node_class].cpu().numpy()
        num_nodes = feature.shape[0]
        features_list.append(feature)
        batch_list.append(torch.full((num_nodes,), graph_id, dtype=torch.long))

        label = np.zeros((class_num,))
        label[graph.y.item()] = 1
        labels_list.append(label)

        # === Edges (local node ids) ===
        e_ind = graph.edge_index.cpu().numpy()
        # Keep the block sparse; block_diag sums duplicates on conversion.
        adjacency_blocks.append(sp.coo_matrix(
            (np.ones(e_ind.shape[1]), (e_ind[0], e_ind[1])),
            shape=(num_nodes, num_nodes),
        ))
        edges_list = [tuple(sorted(edge)) for edge in e_ind.T.tolist()]

        # === Global edge map ===
        # Endpoints are already sorted, so shifting both keeps them sorted.
        # If the same undirected edge appears more than once, the last
        # occurrence wins.
        edge_base = len(all_edges)
        edge_map = {
            (u + offset, v + offset): edge_base + idx
            for idx, (u, v) in enumerate(edges_list)
        }

        # === Rings ===
        # nx.cycle_basis returns *a* cycle basis of the graph; it is not
        # guaranteed to be a minimum cycle basis.
        G = nx.Graph()
        G.add_edges_from(edges_list)
        for cycle in nx.cycle_basis(G):
            if len(cycle) < 3 or len(cycle) > max_ring:
                continue
            ring_edges = []
            for i, u in enumerate(cycle):
                v = cycle[(i + 1) % len(cycle)]
                global_edge = tuple(sorted([u + offset, v + offset]))
                edge_idx = edge_map.get(global_edge)
                if edge_idx is None:
                    print(f"未匹配 edge: {global_edge}")
                    continue
                ring_edges.append(edge_idx)
            if ring_edges:
                # All boundary edges of one ring must share the same column
                # id so that the column identifies the 2-cell.
                all_ring_rows.extend(ring_edges)
                all_ring_cols.extend([num_rings] * len(ring_edges))
                num_rings += 1

        # === Accumulate global edges ===
        all_edges.extend((u + offset, v + offset) for (u, v) in edges_list)
        offset += num_nodes

    # === Stack node features and labels ===
    features = np.vstack(features_list)
    labels = np.vstack(labels_list)

    # === Block-diagonal adjacency ===
    adjacency = sp.block_diag(adjacency_blocks, format='csr')

    # === Global edge_index ===
    # reshape(-1, 2) keeps the result 2 x 0 when the batch has no edges.
    edge_array = np.array(all_edges, dtype=np.int64).reshape(-1, 2)
    edge_index = torch.tensor(edge_array.T, dtype=torch.long)

    # === Cochains ===
    v_cochain = Cochain(dim=0, x=torch.FloatTensor(features), y=torch.FloatTensor(labels))
    e_cochain = Cochain(dim=1, x=torch.ones((edge_index.size(1), 1)), boundary_index=edge_index)

    if all_ring_rows:
        two_cell_boundary = torch.tensor([all_ring_rows, all_ring_cols], dtype=torch.long)
        two_cell_cochain = Cochain(dim=2, boundary_index=two_cell_boundary)
    else:
        two_cell_cochain = None

    complex_obj = Complex(v_cochain, e_cochain, two_cell_cochain)

    # === Batch vector: graph index of every node ===
    batch = torch.cat(batch_list, dim=0)  # shape: [N_total]
    return features, adjacency, labels, complex_obj, batch



def micro_f1(logits, labels):
    """Compute the micro-averaged F1 score of multi-label predictions.

    Predictions are obtained by applying a sigmoid to ``logits`` and rounding
    to 0/1; both predictions and ``labels`` are cast to ``long`` before the
    counts are taken over all elements.

    Args:
        logits: Tensor of raw scores.
        labels: Tensor of targets with the same shape as ``logits``
            (assumed to contain 0/1 values — TODO confirm with callers).

    Returns:
        The F1 score as a Python float. Returns ``0.0`` when there are no
        true positives, where precision/recall would otherwise divide by
        zero.
    """
    preds = torch.round(torch.sigmoid(logits)).long()
    labels = labels.long()

    # Count true positives, false positives and false negatives.
    tp = float(torch.nonzero(preds * labels).shape[0])
    fp = float(torch.nonzero(preds * (labels - 1)).shape[0])
    fn = float(torch.nonzero((preds - 1) * labels).shape[0])

    # With tp == 0 either a denominator below is zero or prec + rec is zero.
    if tp == 0:
        return 0.0

    prec = tp / (tp + fp)
    rec = tp / (tp + fn)
    return (2 * prec * rec) / (prec + rec)

def adj_to_bias(adj, sizes, nhood=1):
    """Convert adjacency matrices into additive attention-bias matrices.

    For every graph the matrix ``(adj + I) ** nhood`` is computed, which
    expands the neighbourhood to ``nhood`` hops and inserts a self-loop on
    every node. Positive entries inside the leading ``sizes[g] x sizes[g]``
    block are clipped to 1; entries outside that block are left as computed.

    Args:
        adj: Array of shape ``[graph, nodes, nodes]``.
        sizes: Per-graph number of nodes whose entries are clipped.
        nhood: Number of hops to expand the neighbourhood by.

    Returns:
        An array of the same shape as ``adj`` equal to ``-1e9 * (1 - mt)``,
        i.e. 0 where ``mt`` is 1 and ``-1e9`` where ``mt`` is 0.
    """
    nb_graphs = adj.shape[0]
    identity = np.eye(adj.shape[1])
    mt = np.empty(adj.shape)
    for g in range(nb_graphs):
        hop = adj[g] + identity
        reach = identity
        for _ in range(nhood):
            reach = np.matmul(reach, hop)
        mt[g] = reach
        # View into mt, so the masked assignment clips in place.
        block = mt[g, :sizes[g], :sizes[g]]
        block[block > 0.0] = 1.0
    return -1e9 * (1.0 - mt)


###############################################
# This section of code adapted from tkipf/gcn #
###############################################

def parse_index_file(filename):
    """Parse an index file containing one integer per line.

    Args:
        filename: Path of the text file to read.

    Returns:
        A list with the integer found on each line, in file order.

    Raises:
        ValueError: If a line cannot be parsed as an integer.
    """
    # The context manager closes the handle even if a line fails to parse.
    with open(filename) as f:
        return [int(line.strip()) for line in f]

def sample_mask(idx, l):
    """Create a boolean mask of length ``l`` that is True at ``idx``.

    Args:
        idx: Index or sequence of indices to mark.
        l: Length of the mask.

    Returns:
        A boolean array of shape ``(l,)``.
    """
    mask = np.zeros(l)
    mask[idx] = 1
    # The builtin ``bool`` replaces the ``np.bool`` alias removed from NumPy.
    return np.array(mask, dtype=bool)

def load_data(dataset_str): # {'pubmed', 'citeseer', 'cora'}
    """Load a citation dataset from the ``data/`` directory.

    Reads the pickled files ``data/ind.<dataset_str>.<name>`` for the names
    ``x, y, tx, ty, allx, ally, graph`` and the index file
    ``data/ind.<dataset_str>.test.index``. Paths are relative to the current
    working directory.

    Args:
        dataset_str: Dataset name used to build the file names.

    Returns:
        A tuple ``(adj, features, labels, idx_train, idx_val, idx_test)``:

        * ``adj``: adjacency matrix of the graph built from ``graph``.
        * ``features``: LIL matrix stacking ``allx`` and ``tx``, with the
          test rows reordered according to the test index file.
        * ``labels``: array stacking ``ally`` and ``ty``, reordered likewise.
        * ``idx_train``: ``range(len(y))``.
        * ``idx_val``: the 500 indices following the training range.
        * ``idx_test``: sorted list of test indices.
    """
    names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
    objects = []
    for name in names:
        with open("data/ind.{}.{}".format(dataset_str, name), 'rb') as f:
            # SECURITY: unpickling can execute arbitrary code; only load
            # dataset files that come from a trusted source.
            objects.append(pkl.load(f, encoding='latin1'))

    x, y, tx, ty, allx, ally, graph = objects
    test_idx_reorder = parse_index_file("data/ind.{}.test.index".format(dataset_str))
    test_idx_range = np.sort(test_idx_reorder)

    if dataset_str == 'citeseer':
        # Fix citeseer dataset (there are some isolated nodes in the graph)
        # Find isolated nodes, add them as zero-vecs into the right position
        test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1)
        tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
        tx_extended[test_idx_range-min(test_idx_range), :] = tx
        tx = tx_extended
        ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
        ty_extended[test_idx_range-min(test_idx_range), :] = ty
        ty = ty_extended

    # Move the test rows (stored in sorted order) to the positions listed in
    # the index file.
    features = sp.vstack((allx, tx)).tolil()
    features[test_idx_reorder, :] = features[test_idx_range, :]
    adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))

    labels = np.vstack((ally, ty))
    labels[test_idx_reorder, :] = labels[test_idx_range, :]

    idx_test = test_idx_range.tolist()
    idx_train = range(len(y))
    idx_val = range(len(y), len(y)+500)

    return adj, features, labels, idx_train, idx_val, idx_test

def sparse_to_tuple(sparse_mx, insert_batch=False):
    """Convert a sparse matrix (or a list of them) to ``(coords, values, shape)``.

    Args:
        sparse_mx: A scipy sparse matrix, or a list of such matrices. A list
            is converted in place and the same list object is returned.
        insert_batch: When True, prepend a batch dimension: every coordinate
            gets a leading 0 and the shape gets a leading 1.

    Returns:
        A ``(coords, values, shape)`` tuple, or the input list with each
        element replaced by such a tuple.
    """
    def _convert(matrix):
        coo = matrix if sp.isspmatrix_coo(matrix) else matrix.tocoo()
        if not insert_batch:
            plain_coords = np.vstack((coo.row, coo.col)).transpose()
            return plain_coords, coo.data, coo.shape
        batch_axis = np.zeros(coo.row.shape[0])
        batched_coords = np.vstack((batch_axis, coo.row, coo.col)).transpose()
        return batched_coords, coo.data, (1,) + coo.shape

    if not isinstance(sparse_mx, list):
        return _convert(sparse_mx)

    for pos, matrix in enumerate(sparse_mx):
        sparse_mx[pos] = _convert(matrix)
    return sparse_mx

def standardize_data(f, train_mask):
    """Standardize the columns of a feature matrix using training rows only.

    Columns whose standard deviation over the training rows is not positive
    are dropped. The remaining columns are shifted by the training mean and
    divided by the training standard deviation.

    Args:
        f: Sparse feature matrix (must provide ``todense()``).
        train_mask: Boolean row selector marking the training rows.

    Returns:
        The dense, standardized matrix restricted to the kept columns.
    """
    dense = f.todense()
    train_rows = train_mask == True

    # Keep only the columns that actually vary across the training rows.
    varying = np.squeeze(np.array(dense[train_rows, :].std(axis=0) > 0))
    dense = dense[:, varying]

    train_block = dense[train_rows, :]
    return (dense - train_block.mean(axis=0)) / train_block.std(axis=0)

def preprocess_features(features):
    """Row-normalize a sparse feature matrix.

    Each row is scaled by the inverse of its sum; rows whose inverse is
    infinite (zero row sum) are scaled by 0 instead.

    Args:
        features: Sparse feature matrix.

    Returns:
        A pair ``(dense, tuple_repr)``: the normalized matrix as a dense
        matrix, and the result of ``sparse_to_tuple`` on the normalized
        sparse matrix.
    """
    row_totals = np.array(features.sum(1))
    inverse_totals = np.power(row_totals, -1).flatten()
    inverse_totals[np.isinf(inverse_totals)] = 0.
    normalized = sp.diags(inverse_totals).dot(features)
    return normalized.todense(), sparse_to_tuple(normalized)

def normalize_adj(adj):
    """Symmetrically normalize an adjacency matrix.

    With ``D`` the diagonal matrix of row sums, this returns
    ``(adj @ D^-1/2).T @ D^-1/2``. Rows with a zero sum get a scaling
    factor of 0 instead of infinity.

    Args:
        adj: Adjacency matrix in any form accepted by ``sp.coo_matrix``.

    Returns:
        The normalized matrix in COO format.
    """
    coo = sp.coo_matrix(adj)
    degrees = np.array(coo.sum(1))
    scale = np.power(degrees, -0.5).flatten()
    scale[np.isinf(scale)] = 0.
    inv_sqrt_degree = sp.diags(scale)
    column_scaled = coo.dot(inv_sqrt_degree)
    return column_scaled.transpose().dot(inv_sqrt_degree).tocoo()


def preprocess_adj(adj):
    """Add self-loops, normalize, and convert an adjacency matrix to a tuple.

    Args:
        adj: Square adjacency matrix.

    Returns:
        The ``sparse_to_tuple`` representation of
        ``normalize_adj(adj + I)``.
    """
    with_self_loops = adj + sp.eye(adj.shape[0])
    return sparse_to_tuple(normalize_adj(with_self_loops))

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse COO tensor.

    Args:
        sparse_mx: A scipy sparse matrix.

    Returns:
        A float32 sparse COO tensor with the same shape and entries.
    """
    coo = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((coo.row, coo.col)).astype(np.int64))
    values = torch.from_numpy(coo.data)
    # torch.sparse_coo_tensor replaces the deprecated legacy constructor
    # torch.sparse.FloatTensor.
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))




